/*
 * Copyright (c) 2009-2012, Peter Abeles. All Rights Reserved.
 *
 * This file is part of Efficient Java Matrix Library (EJML).
 *
 * EJML is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as
 * published by the Free Software Foundation, either version 3
 * of the License, or (at your option) any later version.
 *
 * EJML is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with EJML.  If not, see <http://www.gnu.org/licenses/>.
 */

package org.ejml.alg.dense.mult;

import org.ejml.alg.generic.CodeGeneratorMisc;

import java.io.FileNotFoundException;
import java.io.PrintStream;


/**
 * <p>
 * This class generates code for various matrix matrix multiplication operations.  The code associated
 * with these operators is often only slightly different from each other.  To remove some of the
 * tediousness of writing and maintaining that code, it is autogenerated.
 * </p>
 * <p>
 * To create {@link MatrixMatrixMult} simply run this application and copy the generated file
 * to the appropriate location.
 * </p>
 *
 * @author Peter Abeles
 */
public class GeneratorMatrixMatrixMult {

    /** Destination that all generated source code is printed to. */
    PrintStream stream;

    /**
     * Creates a generator which writes its output to the specified file.
     *
     * @param fileName Path of the file the generated source code is written to.
     * @throws FileNotFoundException If the file cannot be opened for writing.
     */
    public GeneratorMatrixMatrixMult( String fileName ) throws FileNotFoundException {
        stream = new PrintStream(fileName);
    }

    /**
     * Flushes and closes the output stream.  Nothing should be generated after this has been called.
     */
    public void close() {
        stream.close();
    }

    /**
     * Prints the complete MatrixMatrixMult class: the preamble followed by every multiplication
     * variant for each combination of 'alpha' (scaled) and 'add' (accumulate into c), and finally
     * the closing brace.  The output is flushed before returning but the stream is left open.
     */
    public void createClass() {
        // Derived from this class so the generated {@link} can never point at a stale name
        String generatorName = GeneratorMatrixMatrixMult.class.getName();

        String preamble = CodeGeneratorMisc.COPYRIGHT +
                "\n" +
                "package org.ejml.alg.dense.mult;\n" +
                "\n" +
                "import org.ejml.data.RowD1Matrix64F;\n"+
                "\n" +
                "/**\n" +
                " * <p>\n" +
                " * This class contains various types of matrix matrix multiplication operations for {@link RowD1Matrix64F}.\n" +
                " * </p>\n" +
                " * <p>\n" +
                " * Two algorithms that are equivalent can often have very different runtime performance.\n" +
                " * This is because of how modern computers uses fast memory caches to speed up reading/writing to data.\n" +
                " * Depending on the order in which variables are processed different algorithms can run much faster than others,\n" +
                " * even if the number of operations is the same.\n" +
                " * </p>\n" +
                " *\n" +
                " * <p>\n" +
                " * Algorithms that are labeled as 'reorder' are designed to avoid caching jumping issues, some times at the cost\n" +
                " * of increasing the number of operations.  This is important for large matrices.  The straight forward \n" +
                " * implementation seems to be faster for small matrices.\n" +
                " * </p>\n" +
                " * \n" +
                " * <p>\n" +
                " * Algorithms that are labeled as 'aux' use an auxiliary array of length n.  This array is used to create\n" +
                " * a copy of an out of sequence column vector that is referenced several times.  This reduces the number\n" +
                " * of cache misses.  If the 'aux' parameter passed in is null then the array is declared internally.\n" +
                " * </p>\n" +
                " *\n" +
                " * <p>\n" +
                " * Typically the straight forward implementation runs about 30% faster on smaller matrices and\n" +
                " * about 5 times slower on larger matrices.  This is all computer architecture and matrix shape/size specific.\n" +
                " * </p>\n" +
                " * \n" +
                " * <p>\n" +
                " * <center>******** IMPORTANT **********</center>\n" +
                " * This class was auto generated using {@link "+generatorName+"}\n" +
                " * If this code needs to be modified, please modify {@link "+generatorName+"} instead\n" +
                " * and regenerate the code by running that.\n" +
                " * </p>\n" +
                " * \n" +
                " * @author Peter Abeles\n" +
                " */\n"+
                "public class MatrixMatrixMult {\n";

        stream.print(preamble);

        for( int i = 0; i < 2; i++ ) {
            boolean alpha = i == 1;
            for( int j = 0; j < 2; j++ ) {
                boolean add = j == 1;
                printMult_reorder(alpha,add);
                stream.print("\n");
                printMult_small(alpha,add);
                stream.print("\n");
                printMult_aux(alpha,add);
                stream.print("\n");
                printMultTransA_reorder(alpha,add);
                stream.print("\n");
                printMultTransA_small(alpha,add);
                stream.print("\n");
                printMultTransAB(alpha,add);
                stream.print("\n");
                printMultTransAB_aux(alpha,add);
                stream.print("\n");
                printMultTransB(alpha,add);
                stream.print("\n");
            }
        }
        stream.print("}\n");

        // push everything out now so the file is complete even if the caller never closes the stream
        stream.flush();
    }

    /**
     * Creates the generated code which validates the input matrices and, if requested,
     * declares the auxiliary array when the caller passed in null.
     *
     * @param tranA If true the rows and columns of 'a' are swapped in the generated checks.
     * @param tranB If true the rows and columns of 'b' are swapped in the generated checks.
     * @param auxLength Expression for the length of the 'aux' array, or null if there is no 'aux' parameter.
     * @return Generated source code for the checks.
     */
    private String makeBoundsCheck(boolean tranA, boolean tranB, String auxLength)
    {
        String a_numCols = tranA ? "a.numRows" : "a.numCols";
        String a_numRows = tranA ? "a.numCols" : "a.numRows";
        String b_numCols = tranB ? "b.numRows" : "b.numCols";
        String b_numRows = tranB ? "b.numCols" : "b.numRows";

        String ret =
                        "        if( a == c || b == c )\n" +
                        "            throw new IllegalArgumentException(\"Neither 'a' or 'b' can be the same matrix as 'c'\");\n"+
                        "        else if( "+a_numCols+" != "+b_numRows+" ) {\n" +
                        "            throw new MatrixDimensionException(\"The 'a' and 'b' matrices do not have compatible dimensions\");\n" +
                        "        } else if( "+a_numRows+" != c.numRows || "+b_numCols+" != c.numCols ) {\n" +
                        "            throw new MatrixDimensionException(\"The results matrix does not have the desired dimensions\");\n" +
                        "        }\n" +
                        "\n";

        if( auxLength != null ) {
            ret += "        if( aux == null ) aux = new double[ "+auxLength+" ];\n\n";
        }

        return ret;
    }

    /**
     * Creates the Javadoc comment placed above a generated function.  It only contains
     * a 'see' tag which points at the function in CommonOps with the same name.
     *
     * @param nameOp Name of the operation in CommonOps.
     * @param hasAlpha If true the referenced function's first argument is a double.
     * @return Generated Javadoc comment.
     */
    private String makeComment( String nameOp , boolean hasAlpha )
    {
        String a = hasAlpha ? "double, " : "";
        String inputs = "("+a+" org.ejml.data.RowD1Matrix64F, org.ejml.data.RowD1Matrix64F, org.ejml.data.RowD1Matrix64F)";

        return  "    /**\n" +
                "     * @see org.ejml.ops.CommonOps#"+nameOp+inputs+"\n" +
                "     */\n";
    }

    /**
     * Creates the comment, signature and opening brace of a generated function.  The function's
     * name is built from the base name, "Add" if 'add' is true, a "TransA"/"TransB"/"TransAB"
     * suffix depending on which inputs are transposed, and finally "_variant" if a variant is given.
     *
     * @param nameOp Base name of the operation.
     * @param variant Name of the algorithm variant, or null if there is none.
     * @param add If true the function adds to 'c' instead of overwriting it.
     * @param hasAlpha If true the function has a leading 'double alpha' parameter.
     * @param hasAux If true the function has a trailing 'double []aux' parameter.
     * @param tranA If true 'a' is transposed.
     * @param tranB If true 'b' is transposed.
     * @return Generated source code up to and including the opening brace.
     */
    private String makeHeader(String nameOp, String variant,
                              boolean add, boolean hasAlpha, boolean hasAux,
                              boolean tranA, boolean tranB)
    {
        if( add ) nameOp += "Add";

        // make the op name
        if( tranA && tranB ) {
            nameOp += "TransAB";
        } else if( tranA ) {
            nameOp += "TransA";
        } else if( tranB ) {
            nameOp += "TransB";
        }

        String ret = makeComment(nameOp,hasAlpha)+
                     "    public static void "+nameOp;

        if( variant != null ) ret += "_"+variant+"( ";
        else ret += "( ";

        if( hasAlpha ) ret += "double alpha , ";

        if( hasAux ) {
            ret += "RowD1Matrix64F a , RowD1Matrix64F b , RowD1Matrix64F c , double []aux )\n";
        } else {
            ret += "RowD1Matrix64F a , RowD1Matrix64F b , RowD1Matrix64F c )\n";
        }

        ret += "    {\n";

        return ret;
    }

    /**
     * Name of the function called on 'c' to store a result: "plus" when adding
     * to the existing value and "set" when overwriting it.
     */
    private static String assignOp( boolean add ) {
        return add ? "plus" : "set";
    }

    /**
     * Returns the expression multiplied by 'alpha' if alpha is true, otherwise the expression unchanged.
     */
    private static String scale( boolean alpha , String expr ) {
        return alpha ? "alpha*"+expr : expr;
    }

    /**
     * Creates the statement (without leading indentation) which writes the variable 'total',
     * optionally scaled by alpha, into 'c' at the specified index expression.
     */
    private static String storeTotal( String index , boolean alpha , boolean add ) {
        return "c."+assignOp(add)+"( "+index+" , "+scale(alpha,"total")+" );\n";
    }

    /**
     * Misspelled name kept so that existing callers continue to work.
     *
     * @deprecated Use {@link #printMult_reorder(boolean, boolean)} instead.
     */
    public void printMult_reroder( boolean alpha , boolean add ) {
        printMult_reorder(alpha,add);
    }

    /**
     * Prints the 'reorder' variant of c = a*b.
     *
     * @param alpha If true each element of 'a' is scaled by alpha.
     * @param add If true the first pass adds to 'c' instead of overwriting it.
     */
    public void printMult_reorder( boolean alpha , boolean add ) {
        String header = makeHeader("mult","reorder",add,alpha, false, false,false);
        String valLine = "valA = "+scale(alpha,"a.get(indexA++)")+";\n";
        String assignment = assignOp(add);

        String code =
                header + makeBoundsCheck(false,false, null)+
                        "        double valA;\n"+
                        "        int indexCbase= 0;\n" +
                        "        int endOfKLoop = b.numRows*b.numCols;\n"+
                        "\n" +
                        "        for( int i = 0; i < a.numRows; i++ ) {\n" +
                        "            int indexA = i*a.numCols;\n" +
                        "\n"+
                        "            // need to assign c.data to a value initially\n" +
                        "            int indexB = 0;\n" +
                        "            int indexC = indexCbase;\n" +
                        "            int end = indexB + b.numCols;\n" +
                        "\n" +
                        "            "+valLine +
                        "\n" +
                        "            while( indexB < end ) {\n" +
                        "                c."+assignment+"(indexC++ , valA*b.get(indexB++));\n" +
                        "            }\n" +
                        "\n" +
                        "            // now add to it\n"+
                        "            while( indexB != endOfKLoop ) { // k loop\n"+
                        "                indexC = indexCbase;\n" +
                        "                end = indexB + b.numCols;\n" +
                        "\n" +
                        "                "+valLine+
                        "\n" +
                        "                while( indexB < end ) { // j loop\n" +
                        "                    c.plus(indexC++ , valA*b.get(indexB++));\n" +
                        "                }\n" +
                        "            }\n" +
                        "            indexCbase += c.numCols;\n" +
                        "        }\n" +
                        "    }\n";

        stream.print(code);
    }

    /**
     * Prints the 'small' variant of c = a*b.
     *
     * @param alpha If true each sum is scaled by alpha before being stored.
     * @param add If true the result is added to 'c' instead of overwriting it.
     */
    public void printMult_small( boolean alpha , boolean add ) {
        String header = makeHeader("mult","small",add,alpha, false, false,false);

        String code =
                header + makeBoundsCheck(false,false, null)+
                        "        int aIndexStart = 0;\n" +
                        "        int cIndex = 0;\n" +
                        "\n" +
                        "        for( int i = 0; i < a.numRows; i++ ) {\n" +
                        "            for( int j = 0; j < b.numCols; j++ ) {\n" +
                        "                double total = 0;\n" +
                        "\n" +
                        "                int indexA = aIndexStart;\n" +
                        "                int indexB = j;\n" +
                        "                int end = indexA + b.numRows;\n" +
                        "                while( indexA < end ) {\n" +
                        "                    total += a.get(indexA++) * b.get(indexB);\n" +
                        "                    indexB += b.numCols;\n" +
                        "                }\n" +
                        "\n" +
                        "                "+storeTotal("cIndex++",alpha,add) +
                        "            }\n" +
                        "            aIndexStart += a.numCols;\n" +
                        "        }\n" +
                        "    }\n";

        stream.print(code);
    }

    /**
     * Prints the 'aux' variant of c = a*b, which copies each column of 'b' into an auxiliary array.
     *
     * @param alpha If true each sum is scaled by alpha before being stored.
     * @param add If true the result is added to 'c' instead of overwriting it.
     */
    public void printMult_aux( boolean alpha , boolean add ) {
        String header = makeHeader("mult","aux",add,alpha, true, false,false);

        String code =
                header + makeBoundsCheck(false,false, "b.numRows")+
                        "        for( int j = 0; j < b.numCols; j++ ) {\n" +
                        "            // create a copy of the column in B to avoid cache issues\n" +
                        "            for( int k = 0; k < b.numRows; k++ ) {\n" +
                        "                aux[k] = b.unsafe_get(k,j);\n" +
                        "            }\n" +
                        "\n" +
                        "            int indexA = 0;\n" +
                        "            for( int i = 0; i < a.numRows; i++ ) {\n" +
                        "                double total = 0;\n" +
                        "                for( int k = 0; k < b.numRows; ) {\n" +
                        "                    total += a.get(indexA++)*aux[k++];\n" +
                        "                }\n" +
                        "                "+storeTotal("i*c.numCols+j",alpha,add) +
                        "            }\n" +
                        "        }\n" +
                        "    }\n";

        stream.print(code);
    }

    /**
     * Prints the 'reorder' variant of the multiplication in which 'a' is transposed.
     *
     * @param alpha If true each element of 'a' is scaled by alpha.
     * @param add If true the first pass adds to 'c' instead of overwriting it.
     */
    public void printMultTransA_reorder( boolean alpha , boolean add ) {
        String header = makeHeader("mult","reorder",add,alpha, false, true,false);
        String assignment = assignOp(add);

        // value of 'a' used by the first pass (k = 0) and by all the following passes
        String valLine1 = "valA = "+scale(alpha,"a.get(i)")+";\n";
        String valLine2 = "valA = "+scale(alpha,"a.unsafe_get(k,i)")+";\n";

        String code =
                header + makeBoundsCheck(true,false, null)+
                        "        double valA;\n" +
                        "\n" +
                        "        for( int i = 0; i < a.numCols; i++ ) {\n" +
                        "            int indexC_start = i*c.numCols;\n" +
                        "\n" +
                        "            // first assign R\n" +
                        "            " +valLine1+
                        "            int indexB = 0;\n" +
                        "            int end = indexB+b.numCols;\n" +
                        "            int indexC = indexC_start;\n" +
                        "            while( indexB<end ) {\n" +
                        "                c."+assignment+"( indexC++ , valA*b.get(indexB++));\n" +
                        "            }\n" +
                        "            // now increment it\n" +
                        "            for( int k = 1; k < a.numRows; k++ ) {\n" +
                        "                " +valLine2+
                        "                end = indexB+b.numCols;\n" +
                        "                indexC = indexC_start;\n" +
                        "                // this is the loop for j\n" +
                        "                while( indexB<end ) {\n" +
                        "                    c.plus( indexC++ , valA*b.get(indexB++));\n" +
                        "                }\n" +
                        "            }\n" +
                        "        }\n" +
                        "    }\n";

        stream.print(code);
    }

    /**
     * Prints the 'small' variant of the multiplication in which 'a' is transposed.
     *
     * @param alpha If true each sum is scaled by alpha before being stored.
     * @param add If true the result is added to 'c' instead of overwriting it.
     */
    public void printMultTransA_small( boolean alpha , boolean add ) {
        String header = makeHeader("mult","small",add,alpha, false, true,false);

        String code =
                header + makeBoundsCheck(true,false, null)+
                        "        int cIndex = 0;\n" +
                        "\n" +
                        "        for( int i = 0; i < a.numCols; i++ ) {\n" +
                        "            for( int j = 0; j < b.numCols; j++ ) {\n" +
                        "                int indexA = i;\n" +
                        "                int indexB = j;\n" +
                        "                int end = indexB + b.numRows*b.numCols;\n" +
                        "\n" +
                        "                double total = 0;\n" +
                        "\n" +
                        "                // loop for k\n" +
                        "                for(; indexB < end; indexB += b.numCols ) {\n" +
                        "                    total += a.get(indexA) * b.get(indexB);\n" +
                        "                    indexA += a.numCols;\n" +
                        "                }\n" +
                        "\n" +
                        "                "+storeTotal("cIndex++",alpha,add) +
                        "            }\n" +
                        "        }\n" +
                        "    }\n";

        stream.print(code);
    }

    /**
     * Prints the multiplication in which 'b' is transposed.
     *
     * @param alpha If true each sum is scaled by alpha before being stored.
     * @param add If true the result is added to 'c' instead of overwriting it.
     */
    public void printMultTransB( boolean alpha , boolean add ) {
        String header = makeHeader("mult",null,add,alpha, false, false,true);

        String code =
                header + makeBoundsCheck(false,true, null)+
                        "        int cIndex = 0;\n" +
                        "        int aIndexStart = 0;\n" +
                        "\n" +
                        "        for( int xA = 0; xA < a.numRows; xA++ ) {\n" +
                        "            int end = aIndexStart + b.numCols;\n" +
                        "            int indexB = 0;\n"+
                        "            for( int xB = 0; xB < b.numRows; xB++ ) {\n" +
                        "                int indexA = aIndexStart;\n" +
                        "\n" +
                        "                double total = 0;\n" +
                        "\n" +
                        "                while( indexA<end ) {\n" +
                        "                    total += a.get(indexA++) * b.get(indexB++);\n" +
                        "                }\n" +
                        "\n" +
                        "                "+storeTotal("cIndex++",alpha,add) +
                        "            }\n" +
                        "            aIndexStart += a.numCols;\n" +
                        "        }\n" +
                        "    }\n";

        stream.print(code);
    }

    /**
     * Prints the multiplication in which both 'a' and 'b' are transposed.
     *
     * @param alpha If true each sum is scaled by alpha before being stored.
     * @param add If true the result is added to 'c' instead of overwriting it.
     */
    public void printMultTransAB( boolean alpha , boolean add ) {
        String header = makeHeader("mult",null,add,alpha, false, true,true);

        String code =
                header + makeBoundsCheck(true,true, null)+
                        "        int cIndex = 0;\n" +
                        "\n" +
                        "        for( int i = 0; i < a.numCols; i++ ) {\n" +
                        "            int indexB = 0;\n"+
                        "            for( int j = 0; j < b.numRows; j++ ) {\n" +
                        "                int indexA = i;\n" +
                        "                int end = indexB + b.numCols;\n" +
                        "\n" +
                        "                double total = 0;\n" +
                        "\n" +
                        "                for( ;indexB<end; ) {\n" +
                        "                    total += a.get(indexA) * b.get(indexB++);\n" +
                        "                    indexA += a.numCols;\n" +
                        "                }\n" +
                        "\n" +
                        "                "+storeTotal("cIndex++",alpha,add)+
                        "            }\n" +
                        "        }\n"+
                        "    }\n";

        stream.print(code);
    }

    /**
     * Prints the 'aux' variant of the multiplication in which both 'a' and 'b' are transposed.
     * Each column of 'a' is copied into the auxiliary array.
     *
     * @param alpha If true each sum is scaled by alpha before being stored.
     * @param add If true the result is added to 'c' instead of overwriting it.
     */
    public void printMultTransAB_aux( boolean alpha , boolean add ) {
        String header = makeHeader("mult","aux",add,alpha, true, true,true);

        String code =
                header + makeBoundsCheck(true,true, "a.numRows")+
                        "        int indexC = 0;\n" +
                        "        for( int i = 0; i < a.numCols; i++ ) {\n" +
                        "            for( int k = 0; k < b.numCols; k++ ) {\n" +
                        "                aux[k] = a.unsafe_get(k,i);\n" +
                        "            }\n" +
                        "\n" +
                        "            for( int j = 0; j < b.numRows; j++ ) {\n" +
                        "                double total = 0;\n" +
                        "\n" +
                        "                for( int k = 0; k < b.numCols; k++ ) {\n" +
                        "                    total += aux[k] * b.unsafe_get(j,k);\n" +
                        "                }\n" +
                        "                "+storeTotal("indexC++",alpha,add) +
                        "            }\n" +
                        "        }\n"+
                        "    }\n";

        stream.print(code);
    }

    /**
     * Generates MatrixMatrixMult.java in the current working directory.
     *
     * @throws FileNotFoundException If the output file cannot be opened for writing.
     * @throws IllegalStateException If an error occurred while writing the file.
     */
    public static void main( String[] args ) throws FileNotFoundException {
        GeneratorMatrixMatrixMult gen = new GeneratorMatrixMatrixMult("MatrixMatrixMult.java");

        try {
            gen.createClass();

            // PrintStream never throws on a failed write, so the error flag has to be checked explicitly
            if( gen.stream.checkError() )
                throw new IllegalStateException("Failed to write MatrixMatrixMult.java");
        } finally {
            gen.close();
        }
    }
}
